import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import argparse
import json
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from utils import write_img, chw_to_hwc
from datasets.Rain_Dataloader import TestData_for_Rain100L_pad, TestData_for_Rain100H_pad, TestData_for_Rain100L, TestData_for_Rain100H
from datasets.Internet_Dataloader import TestData_for_Internet
from datasets.DDN_Dataloader import TestData_for_DDN, TestData_for_DDN_pad
from datasets.DID_Dataloader import TestData_for_DID, TestData_for_DID_pad
from datasets.SPA_Dataloader import TestData_for_SPA_pad, TestData_for_SPA
from pytorch_msssim import ssim
from utils import *
from utils.utils import *
from skimage.metrics import structural_similarity as compare_ssim
from numpy import *
from models import *

# Command-line options for this evaluation script.
# NOTE: arguments are parsed at import time, so `args` is a module-level global
# and importing this file from another program will parse *that* program's argv.
# In the code below only --model and --num_workers are actually read; the
# remaining options are accepted but currently unused.
parser = argparse.ArgumentParser()
for _flag, _default, _type, _help in (
        ('--model', 'FADformer', str, 'model name'),
        ('--num_workers', 8, int, 'number of workers'),
        ('--data_dir', './data/', str, 'path to dataset'),
        ('--save_dir', './saved_models/', str, 'path to models saving'),
        ('--result_dir', './results/', str, 'path to results saving'),
        ('--exp', 'rain200', str, 'experiment setting'),
):
    parser.add_argument(_flag, default=_default, type=_type, help=_help)
del _flag, _default, _type, _help  # keep loop temporaries out of the module namespace
args = parser.parse_args()

def test(val_loader_full, network, result_dir, save_images=True):
    """Evaluate ``network`` on whole images, reflect-padding inputs to a multiple of 4.

    Each input is padded on the bottom/right so that its height and width are
    multiples of 4. The clamped prediction is cropped back to the original
    size, scored against the target with ``calculate_psnr_torch`` and,
    optionally, written to ``result_dir``.

    Args:
        val_loader_full: iterable of dicts with ``'source'``, ``'target'`` and
            ``'filename'`` entries. Only ``filename[0]`` is used and the output
            is squeezed on dim 0 before saving, so a batch size of 1 is assumed.
        network: model applied to the padded 4-D ``(B, C, H, W)`` input.
        result_dir: directory receiving ``<stem>.png`` predictions; created if
            missing.
        save_images: when False, metrics are computed but no files are written.

    Returns:
        tuple: the running averages of the two values returned per image by
        ``calculate_psnr_torch`` (PSNR, SSIM).
    """
    img_multiple_of = 4
    PSNR_full = AverageMeter()
    SSIM_full = AverageMeter()

    torch.cuda.empty_cache()
    network.eval()
    os.makedirs(result_dir, exist_ok=True)

    for batch in val_loader_full:
        source_img = batch['source'].cuda()
        target_img = batch['target'].cuda()
        file_name = batch['filename'][0]

        height, width = source_img.shape[2], source_img.shape[3]
        # Smallest bottom/right padding that makes each side a multiple of
        # `img_multiple_of` (0 when the side is already aligned).
        pad_h = (-height) % img_multiple_of
        pad_w = (-width) % img_multiple_of
        source_img = F.pad(source_img, (0, pad_w, 0, pad_h), mode='reflect')

        with torch.no_grad():
            output = network(source_img).clamp_(0, 1)

        # Drop the padded border so the prediction aligns with the target.
        output = output[:, :, :height, :width]

        psnr_full, ssim_full = calculate_psnr_torch(target_img, output)
        PSNR_full.update(psnr_full.item(), source_img.size(0))
        SSIM_full.update(ssim_full.item(), source_img.size(0))

        if save_images:
            out_img = chw_to_hwc(output.detach().cpu().squeeze(0).numpy())
            # splitext strips only the final extension, so names containing
            # extra dots (e.g. "a.b.png") do not collide after saving.
            stem = os.path.splitext(file_name)[0]
            write_img(os.path.join(result_dir, stem + '.png'), out_img)

    return PSNR_full.avg, SSIM_full.avg


def test_merge(val_loader_full, network, result_dir, save_images=True):
    """Evaluate ``network`` on whole images by fusing four corner crops.

    Instead of padding, each image is cropped to the largest size whose height
    and width are multiples of 4, anchored at each of the four corners. The
    network is run on every crop, the clamped predictions are summed into a
    full-size canvas, and each pixel is divided by the number of crops that
    covered it.

    Args:
        val_loader_full: iterable of dicts with ``'source'``, ``'target'`` and
            ``'filename'`` entries. Only ``filename[0]`` is used and the output
            is squeezed on dim 0 before saving, so a batch size of 1 is assumed.
        network: model applied to each ``(B, C, crop_H, crop_W)`` crop;
            presumably returns a tensor of the same shape — confirm for new
            models.
        result_dir: directory receiving ``<stem>.png`` predictions; created if
            missing.
        save_images: when False, metrics are computed but no files are written.

    Returns:
        tuple: the running averages of the two values returned per image by
        ``calculate_psnr_torch`` (PSNR, SSIM).

    Raises:
        ValueError: if an image is smaller than 4 pixels in height or width, in
            which case no aligned crop exists (previously this silently
            produced NaNs from a division by zero).
    """
    img_multiple_of = 4
    PSNR_full = AverageMeter()
    SSIM_full = AverageMeter()

    torch.cuda.empty_cache()
    network.eval()
    os.makedirs(result_dir, exist_ok=True)

    for batch in val_loader_full:
        source_img = batch['source'].cuda()
        target_img = batch['target'].cuda()
        file_name = batch['filename'][0]

        B, C, H, W = source_img.shape
        crop_H = H - H % img_multiple_of
        crop_W = W - W % img_multiple_of
        if crop_H == 0 or crop_W == 0:
            raise ValueError(
                f"test_merge needs images of at least {img_multiple_of}x{img_multiple_of} "
                f"pixels, got {H}x{W} for {file_name}")

        rows_top, rows_bottom = slice(0, crop_H), slice(H - crop_H, H)
        cols_left, cols_right = slice(0, crop_W), slice(W - crop_W, W)
        # Corner windows, in the order: top-left, bottom-left, bottom-right, top-right.
        windows = (
            (rows_top, cols_left),
            (rows_bottom, cols_left),
            (rows_bottom, cols_right),
            (rows_top, cols_right),
        )

        output = torch.zeros((B, C, H, W), device=source_img.device)
        # Per-pixel count of covering crops; a single (1, 1, H, W) map
        # broadcasts over batch and channels in the division below.
        coverage = torch.zeros((1, 1, H, W), device=source_img.device)

        with torch.no_grad():
            for rows, cols in windows:
                pred = network(source_img[:, :, rows, cols]).clamp_(0, 1)
                output[:, :, rows, cols] += pred
                coverage[:, :, rows, cols] += 1.

        output = (output / coverage).clamp_(0, 1)

        psnr_full, ssim_full = calculate_psnr_torch(target_img, output)
        PSNR_full.update(psnr_full.item(), source_img.size(0))
        SSIM_full.update(ssim_full.item(), source_img.size(0))

        if save_images:
            out_img = chw_to_hwc(output.detach().cpu().squeeze(0).numpy())
            # splitext strips only the final extension, so names containing
            # extra dots (e.g. "a.b.png") do not collide after saving.
            stem = os.path.splitext(file_name)[0]
            write_img(os.path.join(result_dir, stem + '.png'), out_img)

    return PSNR_full.avg, SSIM_full.avg


if __name__ == '__main__':

    device_index = [0]
    # NOTE(review): eval() runs arbitrary Python taken from --model. That is
    # acceptable for a local script, but a plain name lookup in the module
    # namespace would be safer if this is ever driven by untrusted input.
    network = eval(args.model)()
    network = nn.DataParallel(network, device_ids=device_index).cuda()
    # NOTE(review): this checkpoint lives under 'ddn/' yet is named for
    # Rain200H — confirm it is the intended weights for the test set below.
    weights_path = './pretrain_weights/ddn/FADformer_Rain200H.pth'
    network.load_state_dict(torch.load(weights_path)['state_dict'])

    # Rain200H, Rain200L and SPA-Data: this script saves the restored images and
    # reports PSNR/SSIM directly.
    # DID and DDN: PSNR/SSIM must be computed with the MATLAB evaluation code, so
    # use this script only to write the restored images, then score them in MATLAB.
    test_dir = '/home/jxy/projects_dir/datasets/Rain100/rain_heavy_test'
    # Other test sets:
    #   './datasets/Rain100/rain_data_test_Light'
    #   './datasets/Rain100/rain_heavy_test'

    result_dir = './saved_images/Rain200H'

    # Option 1: test on clipped images for faster prediction.
    # test_dataset = TestData_for_Rain100H(4, test_dir)
    # test_loader = DataLoader(test_dataset,
    #                          batch_size=1,
    #                          shuffle=False,
    #                          num_workers=args.num_workers,
    #                          pin_memory=True)
    # avg_psnr, avg_ssim = test(test_loader, network, result_dir)
    # print(avg_psnr, avg_ssim)

    # Option 2: test on full images via corner-crop merging, to reproduce the
    # performance table in the paper.
    test_dataset = TestData_for_Rain100H_pad(test_dir)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,  # evaluation: keep a deterministic order
                             num_workers=args.num_workers,
                             pin_memory=True)

    # Distinct names avoid rebinding the module-level `ssim` imported from
    # pytorch_msssim.
    avg_psnr, avg_ssim = test_merge(test_loader, network, result_dir)
    print(avg_psnr, avg_ssim)
